自定义训练循环

Note

在极少数情况下,fit()方法可能不够灵活而无法满足需要,这时就需要自定义训练循环。

准备工作

用例子来说明

import tensorflow as tf
from tensorflow  import keras
import utils

# Load the pre-trained model and the dataset splits.
# NOTE(review): `utils` is a project-local helper module — assumed to return
# (train, val, test) pairs of numpy arrays; verify against its definition.
model = keras.models.load_model("my_housing_model")
(X_train, y_train), (X_val, y_val), (X_test, y_test) = utils.load_california_housing()
import numpy as np


def random_batch(X, y, batch_size=32):
    """Draw a random mini-batch of `batch_size` rows (sampled with replacement) from X/y."""
    n_samples = len(X)
    chosen = np.random.randint(n_samples, size=batch_size)
    X_sampled = X[chosen]
    y_sampled = y[chosen]
    return X_sampled, y_sampled
def print_status_bar(iteration, total, loss, metrics=None):
    """Print an in-place status line: "iteration/total - loss - metric ...".

    `loss` and every entry of `metrics` must expose `.name` and `.result()`
    (Keras metric protocol). The line is rewritten in place via `\r` until
    `iteration` reaches `total`, at which point a newline is emitted.
    """
    trackers = [loss] + (metrics or [])
    stats = " - ".join(f"{m.name}: {m.result():.4f}" for m in trackers)
    # Newline only once the epoch is complete; otherwise stay on the same line.
    terminator = "\n" if iteration >= total else ""
    print(f"\r{iteration}/{total} - {stats}", end=terminator)
n_epochs = 5
batch_size = 32
# Number of gradient steps per epoch (drops the last partial batch).
n_steps = len(X_train) // batch_size
# Optimizer, loss function, running-mean loss holder, and tracked metrics.
optimizer = keras.optimizers.Nadam(learning_rate=0.01)
loss_fn = keras.losses.mean_squared_error
mean_loss = keras.metrics.Mean()
metrics = [keras.metrics.MeanAbsoluteError()]

训练循环

for epoch in range(1, n_epochs + 1):
    print("Epoch {}/{}".format(epoch, n_epochs))
    for step in range(1, n_steps + 1):
        # Sample a random mini-batch for this step.
        X_batch, y_batch = random_batch(X_train, y_train)
        # Record the forward pass so the tape can differentiate through it.
        with tf.GradientTape() as tape:
            # FIX: pass training=True so layers such as BatchNorm/Dropout
            # run in training mode (per the TF custom-training-loop guide).
            y_pred = model(X_batch, training=True)
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            # Add auxiliary losses attached to the model (e.g. weight regularization).
            loss = tf.add_n([main_loss] + model.losses)
        # Backpropagation: gradients of the total loss w.r.t. trainable weights.
        gradients = tape.gradient(loss, model.trainable_variables)
        # One gradient-descent step.
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Track the running mean of the loss across the epoch.
        mean_loss(loss)
        # Update each tracked metric with this batch.
        for metric in metrics:
            metric(y_batch, y_pred)
        # In-place progress line for this step.
        print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)
    # Final status line for the epoch (emits the trailing newline).
    print_status_bar(len(y_train), len(y_train), mean_loss, metrics)
    # Reset metric state so each epoch reports fresh numbers.
    for metric in [mean_loss] + metrics:
        metric.reset_states()
Epoch 1/5
11610/11610 - mean: 1.5009 - mean_absolute_error: 0.9104
Epoch 2/5
11610/11610 - mean: 2.5215 - mean_absolute_error: 0.9303
Epoch 3/5
11610/11610 - mean: 1.6988 - mean_absolute_error: 0.9081
Epoch 4/5
11610/11610 - mean: 1.5036 - mean_absolute_error: 0.8840
Epoch 5/5
11610/11610 - mean: 1.3136 - mean_absolute_error: 0.8738